import argparse
import logging
from collections import Counter, defaultdict

import numpy as np
import scipy.io as sio
import torch
from matplotlib import pyplot as plt

from data import Dataset
from network.Network import ResNet18ARM


def parse():
    parser = argparse.ArgumentParser(description='My Awesome FER')
    parser.add_argument('--cuda', default=True, type=bool, help='enable cuda or not')

    parser.add_argument('--test_batch_size', default=64, type=int,
                        help='batch size of testing')
    parser.add_argument('--test_set', default='X:\\combine.txt', type=str,
                        help='path to the index')
    parser.add_argument('--model', default='X:\\epoch55_acc0.920.pth', type=str,
                        help='path to the model')
    parser.add_argument('--mat', default='X:\\result.mat', type=str,
                        help='path to the mat')

    return parser.parse_args()


def main():
    """Run the FER model over the test set and save per-shot labels to a .mat file.

    Every image's predicted class index is grouped under the key
    ``'shot' + <shot id>``. The most frequent prediction within a shot becomes
    that shot's label; ties go to the label seen first. The result is written
    with ``scipy.io.savemat`` under the keys ``'label'`` (int32 array) and
    ``'shot'`` (string array), in matching order.

    Post-processing in MATLAB (remap labels and save):

        B = [0 1 2 3 4 5 6]
        C = [2 1 0 3 1 2 0]
        [~,idx] = ismember(label,B)
        A = C(idx)
        label = int64(A)
        clear A B C idx
        save converted_stage1.mat label shot
    """
    arg = parse()

    # Prepare Data
    logging.info('[Pre] Preparing Data')
    test_dataloader = Dataset.prepare_test(arg)
    # Class index -> name, as used by the former visualisation code:
    # 0 Surprise, 1 Fear, 2 Disgust, 3 Happiness, 4 Sadness, 5 Anger, 6 Neutral

    # Pick one device for both model and inputs. Fall back to CPU when CUDA is
    # requested but not available instead of crashing on the first .cuda().
    use_cuda = arg.cuda and torch.cuda.is_available()
    if arg.cuda and not use_cuda:
        logging.warning('[Pre] CUDA requested but not available; using CPU')
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Build Network
    logging.info('[Pre] Building Network')
    # NOTE(review): pretrained=True may fetch backbone weights that the
    # checkpoint below overwrites anyway -- confirm whether it is needed.
    model = ResNet18ARM(pretrained=True, num_classes=7)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    checkpoint = torch.load(arg.model, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Collect every per-image prediction, grouped by shot.
    result = defaultdict(list)
    with torch.no_grad():
        for batch, (images, _, shots) in enumerate(test_dataloader):
            pred, _ = model(images.to(device))
            _, predicts = torch.max(pred, 1)
            labels = predicts.cpu().tolist()

            for shot, label in zip(shots, labels):
                result['shot' + shot].append(label)

            logging.info('[Predict] batch: %d', batch)

    # Majority vote per shot. Counter keeps first-seen order for ties, which
    # matches the old max(lst, key=lst.count) tie-break, but runs in O(n).
    result_max = {shot: Counter(votes).most_common(1)[0][0]
                  for shot, votes in result.items()}

    shot_names = list(result_max)
    # Keep the historical 12-character width but never truncate longer names;
    # truncation could silently merge distinct shots in the saved file.
    width = max([12] + [len(name) for name in shot_names])
    k = np.array(shot_names, dtype=f'U{width}')
    v = np.fromiter(result_max.values(), dtype='i4')
    result_mat = {'label': v, 'shot': k}
    sio.savemat(arg.mat, result_mat)


if __name__ == '__main__':
    # Configure the root logger verbosely so main()'s INFO progress lines show.
    logging.basicConfig(level='DEBUG')
    main()
